package com.auditlog.sql.runner;

import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil;
import com.auditlog.converter.ConverterRegistry;
import com.auditlog.datasource.ParametersHolder;
import com.auditlog.datasource.StatementProxy;
import com.auditlog.sql.parser.JSqlParserSupport;
import com.auditlog.datasource.struct.RecordImage;
import com.auditlog.datasource.struct.TableRecord;
import com.auditlog.datasource.table.TableMeta;
import com.auditlog.datasource.table.cache.TableMetaCacheFactory;
import com.auditlog.datasource.struct.TableInfo;
import com.auditlog.datasource.table.extrator.TableExtractor;
import com.auditlog.datasource.table.extrator.TableExtractorHolder;
import com.auditlog.exception.NotSupportException;
import com.auditlog.format.RecordImageFormat;
import com.auditlog.format.RecordImageFormatRegistry;
import com.auditlog.util.IndexUtils;
import com.auditlog.util.JdbcConstants;
import lombok.Builder;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;

import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Base SQL runner. Resolves the table targeted by a parsed statement and provides helpers that
 * query the underlying (non-proxied) connection to build before/after record images.
 *
 * <p>Every query helper opens its own JDBC statement on
 * {@code statementProxy.getConnectionProxy().getTargetConnection()}. The helpers that return
 * {@code List<List<Object>>} close that statement before returning. Callers of
 * {@link #executeSimpleSql(String, List, boolean)} own the returned {@link ResultSet} and the
 * statement that produced it ({@link ResultSet#getStatement()}).
 *
 * @author Zhiyang.Zhang
 * @version 1.0
 * @date 2022/11/13 21:05
 */
@Data
@Slf4j
public abstract class AbstractSqlRunner<T extends Statement, S extends net.sf.jsqlparser.statement.Statement> extends JSqlParserSupport implements SqlRunner<S> {
    private StatementProxy<T> statementProxy;
    private S statement;
    private String tableName;
    private String alias;
    private TableMeta tableMeta;

    /** Separator placed between a column name and its value when building row keys. */
    private static final String DELIMITER = "|`@@`|";

    /**
     * Creates a runner for {@code statement}.
     *
     * <p>When the statement is non-null, its table name and alias are extracted. When a table name
     * is found and a proxy is available, the table metadata is loaded from the metadata cache for
     * the proxy's database type.
     *
     * @param statementProxy proxy of the statement being executed (may be {@code null})
     * @param statement      parsed JSqlParser statement (may be {@code null})
     */
    public AbstractSqlRunner(StatementProxy<T> statementProxy, S statement) {
        this.statementProxy = statementProxy;
        this.statement = statement;
        if (statement == null) {
            return;
        }
        TableExtractor extractor = TableExtractorHolder.extractor(statement.getClass());
        TableInfo tableInfo = extractor.extract(statement);
        this.tableName = tableInfo.getName();
        this.alias = tableInfo.getAlias();
        if (StrUtil.isNotBlank(this.tableName) && statementProxy != null) {
            this.tableMeta = TableMetaCacheFactory.getTableMetaCache(statementProxy.getConnectionProxy().getDbType())
                    .getTableMeta(statementProxy.getConnectionProxy().getTargetConnection(),
                            this.tableName,
                            statementProxy.getConnectionProxy().getDataSourceProxy().getResourceId());
        }
    }

    /**
     * @return {@code true} when the proxied connection's database type is Oracle
     */
    public boolean isOracle() {
        // Constant first: a null db type yields false instead of a NullPointerException.
        return JdbcConstants.ORACLE_STR.equals(this.getDbType());
    }

    /**
     * @return whether the current connection context is executing a batch
     */
    public boolean isBatch() {
        return this.getStatementProxy().getConnectionProxy().getConnectionContext().isBatch();
    }

    /**
     * @return {@code true} when the proxied statement carries bound parameters ({@link ParametersHolder})
     */
    public boolean isPrepareStatement() {
        return this.getStatementProxy() instanceof ParametersHolder;
    }

    /**
     * @return the database type reported by the connection proxy
     */
    public String getDbType() {
        return this.getStatementProxy().getConnectionProxy().getDbType();
    }

    /**
     * Builds a {@link TableRecord} keyed by each row's primary-key values.
     *
     * <p>A row whose key contains a {@code null} part cannot be identified. It is skipped on its
     * own, without shifting the pairing of later rows and keys.
     *
     * @param rows                 query result rows
     * @param primaryKeyWithAlias  primary-key column names (as they appear in {@code allColumnsWithAlias})
     * @param allColumnsWithAlias  all selected column names, in result order
     * @return record whose rows map is keyed by {@link #assembleKey(List, List)}
     */
    public TableRecord buildTableRecord(List<List<Object>> rows, List<String> primaryKeyWithAlias, List<String> allColumnsWithAlias) {
        List<Integer> indexList = IndexUtils.getIndexList(primaryKeyWithAlias, allColumnsWithAlias);
        Map<String, List<Object>> rowMap = new HashMap<>();
        for (List<Object> row : rows) {
            // Derive the key from this very row so key and row can never drift apart.
            List<Object> pkValues = this.indexValue4List(row, indexList, false);
            if (pkValues.isEmpty()) {
                // No part of a composite primary key may be null, otherwise the row cannot be identified.
                continue;
            }
            rowMap.put(this.assembleKey(primaryKeyWithAlias, pkValues), row);
        }
        return TableRecord.builder().delimiter(this.getDelimiter()).rows(rowMap).build();
    }


    /**
     * Records generated primary-key values as inserted rows.
     *
     * <p>A key that matches one of the image's possibly-updated/deleted keys is ignored, because
     * that row already existed in the database. Any other key is added to the image's insert keys.
     *
     * @param image record image being populated
     * @param ids   primary-key values produced by the statement
     */
    public void populatePks(RecordImage image, List<Object> ids) {
        String insertKey = this.assembleKey(image.getPkColumnsName(), ids);
        boolean alreadyExists = false;
        for (List<Object> possibleUpdatePk : image.getPossibleUpdateOrDeletePks()) {
            if (insertKey.equals(this.assembleKey(image.getPkColumnsName(), possibleUpdatePk))) {
                alreadyExists = true;
                break;
            }
        }
        if (alreadyExists) {
            log.debug("产生的key:{},在数据库已存在.", ids);
        } else {
            // Newly inserted row.
            image.getInsertPks().add(ids);
        }
    }

    /**
     * @return the parameters bound on the proxied statement (index to per-batch values), or an
     *         empty map when the statement carries no parameters
     */
    public Map<Integer, ArrayList<Object>> getParameters() {
        Map<Integer, ArrayList<Object>> params = new HashMap<>(1);
        if (this.getStatementProxy() instanceof ParametersHolder) {
            ParametersHolder parametersHolder = (ParametersHolder) this.getStatementProxy();
            params = parametersHolder.getParameters();
        }
        return params;
    }

    /**
     * Builds a {@link PlaceHolderSqlContext} whose table name, primary-key names and column names
     * are taken from {@code tableMeta}.
     */
    public PlaceHolderSqlContext getSqlContext(String sql, List<Integer> placeHolderIndex, Map<Integer, ArrayList<Object>> params, TableMeta tableMeta, boolean lock) {
        return PlaceHolderSqlContext.builder()
                .originSql(sql).placeHodlerIndex(placeHolderIndex).sqlParameters(params)
                .lock(lock).pkColumnsNameWithNoAlias(tableMeta.getPkNameListWithAlias(""))
                .tableNameWithNoAlias(tableMeta.getTableName()).selectColumnsWithNoAlias(tableMeta.getColumnNamesByOrder())
                .build();
    }

    /**
     * Same as {@link #getSqlContext} using this runner's bound parameters and table metadata.
     */
    public PlaceHolderSqlContext getCurrentSqlContext(String sql, List<Integer> placeHolderIndex, boolean lock) {
        return getSqlContext(sql, placeHolderIndex, this.getParameters(), this.getTableMeta(), lock);
    }

    /**
     * Executes the context's SQL with its bound parameters, optionally locking the matched rows.
     *
     * <p>With several batch entries the SQL is repeated and joined with {@code union}. When
     * {@code lock} is set, the result is wrapped in a primary-key {@code in (...) for update} query.
     *
     * @param holderSqlContext SQL, placeholder positions, parameters and lock settings
     * @return one list of column values per result row
     * @throws SQLException if the query fails
     */
    public List<List<Object>> executeParametersHolderSqlWithLockTable(PlaceHolderSqlContext holderSqlContext) throws SQLException {
        holderSqlContext.validate();
        ExecuteSqlContext executeSqlContext = getPlaceHolderSql(holderSqlContext.getOriginSql(), holderSqlContext.getPlaceHodlerIndex(), holderSqlContext.getSqlParameters());
        String sqlToRun;
        if (holderSqlContext.isLock()) {
            sqlToRun = this.assembleLockSql(holderSqlContext.getTableNameWithNoAlias(), holderSqlContext.getSelectColumnsWithNoAlias(),
                    holderSqlContext.getPkColumnsNameWithNoAlias(), executeSqlContext.getSql());
        } else {
            // Must be the (possibly union-expanded) SQL that matches the prepared parameter list;
            // the raw origin SQL has too few placeholders when the batch size is > 1.
            sqlToRun = executeSqlContext.getSql();
        }
        return this.queryForRows(sqlToRun, executeSqlContext.getParameters(), false);
    }

    /**
     * Executes {@code sql} using the parameters added to the statement (batch entries are
     * combined with {@code union}).
     *
     * @param sql              SQL to execute
     * @param placeHolderIndex 1-based parameter indexes of the placeholders in {@code sql}
     * @param parameters       statement parameters (index to per-batch values)
     * @return one list of column values per result row
     * @throws SQLException if the query fails
     */
    public List<List<Object>> executeParametersHolderSql(String sql, List<Integer> placeHolderIndex, Map<Integer, ArrayList<Object>> parameters) throws SQLException {
        ExecuteSqlContext executeSqlContext = getPlaceHolderSql(sql, placeHolderIndex, parameters);
        return this.queryForRows(executeSqlContext.getSql(), executeSqlContext.getParameters(), false);
    }


    /**
     * Flattens the batch parameters for the given placeholders and, for more than one batch entry,
     * repeats the SQL joined by {@code union} so the flattened list binds in order.
     */
    private ExecuteSqlContext getPlaceHolderSql(String orignSql, List<Integer> placeHolderIndex, Map<Integer, ArrayList<Object>> parameters) {
        List<Object> palceHolderParameters = new ArrayList<>();
        String sql = orignSql;
        if (CollectionUtil.isNotEmpty(placeHolderIndex)) {
            if (CollectionUtil.isEmpty(parameters)) {
                throw new NotSupportException("statement没有绑定可用的参数");
            }
            List<ArrayList<Object>> boundValues = new ArrayList<>(placeHolderIndex.size());
            for (Integer holderIndex : placeHolderIndex) {
                ArrayList<Object> values = parameters.get(holderIndex);
                if (values == null) {
                    throw new NotSupportException("statement没有绑定占位符参数: " + holderIndex);
                }
                boundValues.add(values);
            }
            // Every bound index carries one value per batch entry; take the size from a placeholder
            // actually used by this SQL rather than assuming index 1 is present.
            int batchSize = boundValues.get(0).size();
            for (int batch = 0; batch < batchSize; batch++) {
                for (ArrayList<Object> values : boundValues) {
                    palceHolderParameters.add(values.get(batch));
                }
            }
            if (batchSize > 1) {
                StringBuilder sqlBuilder = new StringBuilder(orignSql);
                for (int i = 1; i < batchSize; i++) {
                    sqlBuilder.append(" union ").append(orignSql);
                }
                sql = sqlBuilder.toString();
            }
        }
        return ExecuteSqlContext.builder().sql(sql).parameters(palceHolderParameters).build();
    }


    /**
     * Executes {@code sql} and collects every row into a list.
     *
     * @param sql        SQL to execute
     * @param parameters positional parameters (empty for a plain statement)
     * @param lockSql    whether to append {@code for update}
     * @return one list of column values per result row
     * @throws SQLException if the query fails
     */
    public List<List<Object>> executeSimpleSqlWithResult(String sql, List<Object> parameters, boolean lockSql) throws
            SQLException {
        return this.queryForRows(sql, parameters, lockSql);
    }

    /**
     * Reads all rows of {@code resultSet} (one list per row, columns in order) and closes the
     * result set. The statement that produced it is left open.
     *
     * @param resultSet result set to drain
     * @return one list of column values per row
     * @throws SQLException if reading fails
     */
    public List<List<Object>> assembleResult(ResultSet resultSet) throws SQLException {
        try {
            int columnCount = resultSet.getMetaData().getColumnCount();
            List<List<Object>> result = new ArrayList<>();
            while (resultSet.next()) {
                List<Object> row = new ArrayList<>(columnCount);
                for (int i = 1; i <= columnCount; i++) {
                    row.add(resultSet.getObject(i));
                }
                result.add(row);
            }
            return result;
        } finally {
            resultSet.close();
        }
    }

    /**
     * Executes {@code sql} without locking.
     *
     * @param sql        SQL to execute
     * @param parameters positional parameters (empty for a plain statement)
     * @return the open result set; the caller must close it and its statement
     * @throws SQLException if the query fails
     */
    public ResultSet executeSimpleSql(String sql, List<Object> parameters) throws SQLException {
        return this.executeSimpleSql(sql, parameters, false);
    }

    /**
     * Executes {@code sql} on a new statement of the target connection.
     *
     * <p>A {@link PreparedStatement} is used when parameters are given, otherwise a plain
     * {@link Statement}. If binding or execution fails, the statement is closed before the
     * exception propagates.
     *
     * @param sql        SQL to execute
     * @param parameters positional parameters (empty for a plain statement)
     * @param lockSql    whether to append {@code for update}
     * @return the open result set; the caller must close it and {@link ResultSet#getStatement()}
     * @throws SQLException if the query fails
     */
    public ResultSet executeSimpleSql(String sql, List<Object> parameters, boolean lockSql) throws SQLException {
        String sqlToRun = lockSql ? sql + " for update" : sql;
        log.debug("execute sql: {}", sqlToRun);
        if (CollectionUtil.isNotEmpty(parameters)) {
            PreparedStatement preparedStatement = this.getStatementProxy().getConnectionProxy().getTargetConnection().prepareStatement(sqlToRun);
            try {
                for (int index = 0; index < parameters.size(); index++) {
                    // JDBC parameter positions are 1-based.
                    preparedStatement.setObject(index + 1, parameters.get(index));
                }
                return preparedStatement.executeQuery();
            } catch (SQLException | RuntimeException e) {
                closeAfterFailure(preparedStatement, e);
                throw e;
            }
        }
        Statement plainStatement = this.getStatementProxy().getConnectionProxy().getTargetConnection().createStatement();
        try {
            return plainStatement.executeQuery(sqlToRun);
        } catch (SQLException | RuntimeException e) {
            closeAfterFailure(plainStatement, e);
            throw e;
        }
    }

    /**
     * Runs {@link #executeSimpleSql(String, List, boolean)}, drains the result, then closes both
     * the result set and the statement that produced it.
     */
    private List<List<Object>> queryForRows(String sql, List<Object> parameters, boolean lockSql) throws SQLException {
        ResultSet resultSet = this.executeSimpleSql(sql, parameters, lockSql);
        // executeSimpleSql opens a fresh statement on every call; closing only the ResultSet would
        // leak it. try-with-resources skips close() if getStatement() returns null.
        try (Statement owningStatement = resultSet.getStatement()) {
            return this.assembleResult(resultSet);
        }
    }

    /**
     * Closes a statement after a failed bind/execute, attaching any close failure to the primary
     * exception instead of hiding it.
     */
    private static void closeAfterFailure(Statement jdbcStatement, Exception primary) {
        try {
            jdbcStatement.close();
        } catch (SQLException closeFailure) {
            primary.addSuppressed(closeFailure);
        }
    }


    /**
     * Loads rows by primary key.
     *
     * <p>The primary-key columns are selected first, followed by {@code selectColumns}. Each
     * result row is stored under the key built from its primary-key values; the value holds only
     * the {@code selectColumns} part.
     *
     * @param tableNameWithAlias table name (optionally followed by its alias)
     * @param pkNames            primary-key column names
     * @param pkValues           one list of primary-key values per row to load
     * @param selectColumns      columns to load (usually all columns)
     * @return record keyed by primary key; empty when {@code pkValues} is empty
     * @throws SQLException if the query fails
     */
    public TableRecord getTableRecord(String tableNameWithAlias, List<String> pkNames, List<List<Object>> pkValues, List<String> selectColumns) throws SQLException {
        Map<String, List<Object>> columnMap = new HashMap<>();
        if (CollectionUtil.isEmpty(pkValues)) {
            // Nothing to look up; avoid issuing "... where " with an empty condition.
            return TableRecord.builder().rows(columnMap).delimiter(this.getDelimiter()).build();
        }
        int pkCount = pkNames.size();
        List<String> columns = new ArrayList<>(pkCount + selectColumns.size());
        // Select the primary-key values too, so each row can be keyed.
        columns.addAll(pkNames);
        // Then every requested column.
        columns.addAll(selectColumns);
        String sql = assembleSelect(columns, tableNameWithAlias) + " where " + assembleWhere(pkNames, pkValues);

        // Bind row by row, column by column, matching the "(pk1 = ? and pk2 = ?) or (...)" order.
        List<Object> bindValues = new ArrayList<>(pkCount * pkValues.size());
        for (List<Object> pkRow : pkValues) {
            for (int i = 0; i < pkCount; i++) {
                bindValues.add(pkRow.get(i));
            }
        }

        for (List<Object> row : this.queryForRows(sql, bindValues, false)) {
            int split = Math.min(pkCount, row.size());
            List<Object> pks = new ArrayList<>(row.subList(0, split));
            List<Object> allColumns = new ArrayList<>(row.subList(split, row.size()));
            // Primary key as the map key, row values as the map value.
            columnMap.put(assembleKey(pkNames, pks), allColumns);
        }
        return TableRecord.builder().rows(columnMap).delimiter(this.getDelimiter()).build();
    }


    /**
     * Builds a key from column names and values using {@link #getDelimiter()}.
     *
     * @param columnName  column names
     * @param columnValue column values, aligned with {@code columnName}
     * @return concatenation of {@code name + delimiter + value} for each column
     */
    public String assembleKey(List<String> columnName, List<Object> columnValue) {
        return this.assembleKey(columnName, columnValue, this.getDelimiter());
    }

    /**
     * Builds a key from column names and values using {@code deli}.
     *
     * <p>Values are rendered with {@link ConverterRegistry}. If that conversion throws an
     * {@link SQLException}, the value falls back to {@code ObjectUtil.toString}.
     *
     * @param columnName  column names
     * @param columnValue column values, aligned with {@code columnName}
     * @param deli        separator placed between each name and its value
     * @return concatenation of {@code name + deli + value} for each column
     */
    public String assembleKey(List<String> columnName, List<Object> columnValue, String deli) {
        StringBuilder builder = new StringBuilder();
        int index = columnName.size();
        for (int i = 0; i < index; i++) {
            String keyValue;
            try {
                keyValue = ConverterRegistry.getInstance().convert2Str(columnValue.get(i));
            } catch (SQLException throwables) {
                keyValue = ObjectUtil.toString(columnValue.get(i));
            }
            builder.append(columnName.get(i)).append(deli).append(keyValue);
        }
        return builder.toString();
    }


    /**
     * @return the separator used between column names and values in row keys
     */
    public String getDelimiter() {
        return DELIMITER;
    }

    /**
     * Builds a placeholder where-clause fragment (without the {@code where} keyword) of the form
     * {@code ( c1 = ? and c2 = ? ) or ( c1 = ? and c2 = ? ) ...}, one group per entry of
     * {@code values}.
     *
     * @param columns condition columns
     * @param values  value rows; only their count is used here
     * @return the fragment
     */
    public String assembleWhere(List<String> columns, List<List<Object>> values) {
        StringBuilder builder = new StringBuilder();
        for (int c = 0; c < values.size(); c++) {
            builder.append(c == 0 ? " ( " : " or ( ");
            for (int i = 0; i < columns.size(); i++) {
                if (i != 0) {
                    builder.append(" and ");
                }
                builder.append(columns.get(i)).append(" = ? ");
            }
            if (!columns.isEmpty()) {
                builder.append(" ) ");
            }
        }
        return builder.toString();
    }

    /**
     * Builds a {@code for update} query that locks the rows selected by an arbitrary query.
     *
     * @param tableNameWithNoAlias         table name (no alias)
     * @param selectColumnsNameWithNoAlias columns to select (no table alias)
     * @param pkColumnsNameWithNoAlias     primary-key column names (no table alias)
     * @param originSql                    query whose primary keys select the rows to lock
     * @return the locking SQL
     */
    public String assembleLockSql(String tableNameWithNoAlias, List<String> selectColumnsNameWithNoAlias, List<String> pkColumnsNameWithNoAlias, String originSql) {
        String tableAlias = getTempTableName();
        Assert.isTrue(CollectionUtil.isNotEmpty(pkColumnsNameWithNoAlias), "主键不能为空");
        StringBuilder sqlBuilder = new StringBuilder();
        String select = this.assembleSelect(selectColumnsNameWithNoAlias, tableNameWithNoAlias);
        sqlBuilder.append(select).append(" where ");
        sqlBuilder.append(this.list2String(pkColumnsNameWithNoAlias, true));
        sqlBuilder.append(" in ");
        sqlBuilder.append("(").append(" select ").append(this.list2String(pkColumnsNameWithNoAlias, false))
                .append(" from ").append(" ( ").append(originSql).append(" ) ").append(tableAlias)
                .append(" ) for update");

        return sqlBuilder.toString();
    }

    /**
     * @return a derived-table alias of the form {@code t<currentTimeMillis>} followed by a space
     */
    public String getTempTableName() {
        return "t" + System.currentTimeMillis() + " ";
    }

    /**
     * Joins names with {@code ,}, optionally wrapped in parentheses.
     */
    public String list2String(List<String> stringList, boolean containsBrackets) {
        String joined = String.join(",", stringList);
        return containsBrackets ? "(" + joined + ")" : joined;
    }


    /**
     * Extracts the values at the given 1-based indexes from every row.
     *
     * @param columnsValue row data
     * @param columnsIndex 1-based indexes of the values to extract
     * @param containsNull whether {@code null} values are kept; if not, a row with a null at any
     *                     requested index is dropped
     * @return extracted values per kept row
     */
    public List<List<Object>> getIndexValue4List(List<List<Object>> columnsValue, List<Integer> columnsIndex, boolean containsNull) {
        List<List<Object>> objects = new ArrayList<>();
        for (List<Object> value : columnsValue) {
            List<Object> indexValueList = indexValue4List(value, columnsIndex, containsNull);
            if (!indexValueList.isEmpty()) {
                objects.add(indexValueList);
            }
        }
        return objects;
    }


    /**
     * Extracts the values at the given 1-based indexes from one row.
     *
     * @return the extracted values, or an empty list when {@code containsNull} is {@code false} and
     *         any extracted value is {@code null}
     */
    public List<Object> indexValue4List(List<Object> columnValues, List<Integer> columnsIndex,
                                        boolean containsNull) {
        List<Object> vObject = new ArrayList<>();
        for (Integer iindex : columnsIndex) {
            Object o = columnValues.get(iindex - 1);
            if (o == null && !containsNull) {
                // A unique/primary key with a null part can never match an existing row, so the
                // row is necessarily an insert; drop the whole key.
                vObject.clear();
                break;
            }
            vObject.add(o);
        }
        return vObject;
    }


    /**
     * Appends {@code columns} to {@code target}.
     *
     * @param target  list to append to
     * @param columns values to append
     * @param exclude when {@code true}, values already present in {@code target} are skipped
     */
    public void combineColumns(List<String> target, List<String> columns, boolean exclude) {
        if (CollectionUtil.isNotEmpty(columns)) {
            for (String column : columns) {
                if (!exclude || !target.contains(column)) {
                    target.add(column);
                }
            }
        }
    }


    /**
     * Builds {@code " select c1,c2 from <tableName> "}.
     *
     * @param columns   columns to select
     * @param tableName table name (optionally with alias)
     * @return the select fragment
     */
    public String assembleSelect(List<String> columns, String tableName) {
        return columns.stream().collect(Collectors.joining(",", " select ", " from " + tableName + " "));
    }

    /**
     * Formats {@code recordImage} with the format registered for its SQL type; no-op for {@code null}.
     */
    public void formatRecordImage(RecordImage recordImage) {
        if (recordImage == null) {
            return;
        }
        RecordImageFormat recordImageFormat = RecordImageFormatRegistry.getInstance().getRecordImageFormat(recordImage.getSqlType());
        recordImageFormat.format(recordImage);
    }


    /**
     * Sets the after-image of an insert by reloading the inserted keys (empty record when none).
     */
    public void handleInsertAfterRecordImage(RecordImage recordImage) throws SQLException {
        TableRecord tableRecord = TableRecord.builder().delimiter(this.getDelimiter()).build();
        if (CollectionUtil.isNotEmpty(recordImage.getInsertPks())) {
            tableRecord = this.getTableRecord(recordImage.getTableName(), recordImage.getPkColumnsName(), recordImage.getInsertPks(), recordImage.getAllColumnsWithAlias());
        }
        recordImage.setAfterRecord(tableRecord);
    }

    /**
     * Sets an empty after-image: a delete leaves no rows to capture.
     */
    public void handleDeleteAfterRecordImage(RecordImage recordImage) {
        recordImage.setAfterRecord(TableRecord.builder().delimiter(this.getDelimiter()).build());
    }

    /**
     * Sets the after-image of an update by reloading the affected keys.
     *
     * @throws IllegalStateException if the before and after images report different affected-row counts
     */
    public void handleUpdateAfterRecordImage(RecordImage recordImage) throws SQLException {
        TableRecord tableRecord = TableRecord.builder().delimiter(this.getDelimiter()).build();
        if (CollectionUtil.isNotEmpty(recordImage.getPossibleUpdateOrDeletePks())) {
            tableRecord = this.getTableRecord(recordImage.getTableName(), recordImage.getPkColumnsName(), recordImage.getPossibleUpdateOrDeletePks(), recordImage.getAllColumnsWithAlias());
        }
        recordImage.setAfterRecord(tableRecord);
        // NOTE(review): correct only if getAffectedRows() returns a primitive; if it returns a boxed
        // Integer, != compares references. Confirm in TableRecord.
        if (recordImage.getBeforeRecord().getAffectedRows() != recordImage.getAfterRecord().getAffectedRows()) {
            throw new IllegalStateException("更新前后数据不一致");
        }
    }

    @Override
    public String getAlias() {
        return this.alias;
    }

    @Override
    public String getTableName() {
        return this.tableName;
    }

    /**
     * @return {@code "table alias"}, or just the table name when there is no alias
     */
    @Override
    public String getTableNameWithAlias() {
        return StrUtil.isBlank(this.getAlias()) ? this.getTableName() : this.getTableName() + " " + this.getAlias();
    }

    /**
     * Inputs for {@link #executeParametersHolderSqlWithLockTable(PlaceHolderSqlContext)}.
     */
    @Builder
    @Data
    static class PlaceHolderSqlContext {
        private String tableNameWithNoAlias;
        private List<String> pkColumnsNameWithNoAlias;
        private List<String> selectColumnsWithNoAlias;
        private String originSql;
        /** 1-based statement parameter indexes of the placeholders in {@code originSql}. */
        private List<Integer> placeHodlerIndex;
        private boolean lock;
        private Map<Integer, ArrayList<Object>> sqlParameters;


        /**
         * @throws IllegalArgumentException if locking lacks table/pk/column info, or placeholders
         *                                  are present without parameters
         */
        public void validate() {
            if (lock) {
                if (StrUtil.isEmpty(tableNameWithNoAlias) || CollectionUtil.isEmpty(pkColumnsNameWithNoAlias) || CollectionUtil.isEmpty(selectColumnsWithNoAlias)) {
                    throw new IllegalArgumentException("必须指定表名、主键名和待查询的列");
                }
            }
            if (CollectionUtil.isNotEmpty(placeHodlerIndex) && CollectionUtil.isEmpty(sqlParameters)) {
                throw new IllegalArgumentException("sql中包含占位符，但是没有指定参数值");
            }
        }
    }

    /**
     * Final SQL text plus the flattened positional parameters that bind to it.
     */
    @Data
    @Builder
    static class ExecuteSqlContext {
        private String sql;
        private List<Object> parameters;
    }
}

